#!/usr/bin/python3
# ******************************************************************************
# Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
# licensed under the Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#     http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN 'AS IS' BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR
# PURPOSE.
# See the Mulan PSL v2 for more details.
# ******************************************************************************/
"""
Time:
Author:
Description: vulnerability related database operation
"""
from time import time
import math
import json
from collections import defaultdict
from sqlalchemy.exc import SQLAlchemyError
from elasticsearch import ElasticsearchException

from aops_utils.log.log import LOGGER
from aops_utils.database.helper import sort_and_page, judge_return_code
from aops_utils.database.proxy import MysqlProxy, ElasticsearchProxy
from aops_utils.restful.status import DATABASE_DELETE_ERROR, DATABASE_INSERT_ERROR, NO_DATA, \
    DATABASE_QUERY_ERROR, DATABASE_UPDATE_ERROR, SUCCEED, PARAM_ERROR, SERVER_ERROR
from aops_utils.database.table import Host, User
from cve_manager.database.table import Cve, Task, TaskCveHostAssociation, TaskHostRepoAssociation, \
    CveTaskAssociation, CveHostAssociation, CveAffectedPkgs, CveUserAssociation
from cve_manager.conf.constant import TASK_INDEX
from cve_manager.function.customize_exception import EsOperationError


# Task types handled by TaskMysqlProxy when splitting tasks for progress queries;
# used as a filter on the Task table so rows of any other (future) type are ignored.
task_types = ["cve", "repo"]


class TaskMysqlProxy(MysqlProxy):
    """
    Task related mysql table operation
    """

    def get_scan_host_info(self, username, host_list):
        """
        Query basic info of the given hosts owned by a user.

        Args:
            username (str): user name
            host_list (list): host id list; an empty list means every host of the user

        Returns:
            list: host info, e.g.
                [
                    {
                        "host_id": "",
                        "host_ip": "",
                        "host_name": "",
                        "status": ""
                    }
                ]
                An empty list is returned when the database query fails.
        """
        try:
            host_infos = self._get_host_info(username, host_list)
        except SQLAlchemyError as error:
            LOGGER.error(error)
            LOGGER.error("Getting host info failed due to internal error.")
            return []
        LOGGER.debug("Finished getting host info.")
        return host_infos

    def _get_host_info(self, username, host_list):
        """
        Query id, name, public ip and status of a user's hosts.

        Args:
            username (str): user name
            host_list (list): host id list; when empty, all hosts of the user are queried

        Returns:
            list: one dict per host with keys "host_id", "host_name", "host_ip", "status"
        """
        conditions = [Host.user == username]
        if host_list:
            conditions.append(Host.host_id.in_(host_list))

        rows = self.session.query(Host.host_id, Host.host_name, Host.public_ip, Host.status) \
            .filter(*conditions)

        return [
            {
                "host_id": row.host_id,
                "host_name": row.host_name,
                "host_ip": row.public_ip,
                "status": row.status
            }
            for row in rows
        ]

    def get_total_host_info(self):
        """
        Collect basic info of every host, grouped by the owning user.

        Returns:
            int: status code
            dict: query result, e.g.
                {"host_infos": {"admin": [{"host_id": "", "host_name": "",
                                           "host_ip": "", "status": ""}]}}
                On a database error, whatever was collected before the failure is
                returned together with DATABASE_QUERY_ERROR.
        """
        host_infos = {}
        try:
            for user in self.session.query(User).all():
                user_hosts = []
                host_infos[user.username] = user_hosts
                for host in user.hosts:
                    user_hosts.append({
                        "host_id": host.host_id,
                        "host_name": host.host_name,
                        "host_ip": host.public_ip,
                        "status": host.status
                    })
        except SQLAlchemyError as error:
            LOGGER.error(error)
            LOGGER.error("query host basic info fail")
            return DATABASE_QUERY_ERROR, {"host_infos": host_infos}
        return SUCCEED, {"host_infos": host_infos}

    def init_host_scan(self, username, host_list):
        """
        Mark hosts as 'scanning' and stamp their last scan time with the current time.

        All-or-nothing: if any given host id is not found for the user, no host is
        updated and NO_DATA is returned.

        Args:
            username (str): user name
            host_list (list): host id list; an empty list means all hosts of the user

        Returns:
            int: status code (DATABASE_UPDATE_ERROR when the database operation fails)
        """
        try:
            code = self._update_host_scan("init", host_list, username)
            self.session.commit()
        except SQLAlchemyError as error:
            self.session.rollback()
            LOGGER.error(error)
            LOGGER.error("Init host scan status failed due to internal error.")
            return DATABASE_UPDATE_ERROR
        LOGGER.debug("Finished init host scan status.")
        return code

    def update_scan_status(self, host_list):
        """
        Set the status of scanned hosts to 'done'.

        Host ids that are not found are only logged; the found ones are still updated.

        NOTE(review): no user filter is applied here, and an empty host_list applies no
        host filter either, so [] would mark every host in the table as 'done' --
        confirm callers never pass an empty list.

        Args:
            host_list (list): host id list

        Returns:
            int: status code (DATABASE_UPDATE_ERROR when the database operation fails)
        """
        try:
            code = self._update_host_scan("finish", host_list)
            self.session.commit()
        except SQLAlchemyError as error:
            self.session.rollback()
            LOGGER.error(error)
            LOGGER.error("Updating host status after scanned failed due to internal error.")
            return DATABASE_UPDATE_ERROR
        LOGGER.debug("Finished updating host status after scanned.")
        return code

    def _update_host_scan(self, update_type, host_list, username=None):
        """
        Update the scan status (and, on init, the last_scan time) of hosts.

        Args:
            update_type (str): 'init' sets status to 'scanning' and last_scan to now;
                'finish' sets status to 'done'
            host_list (list): host id list; when empty, no host id filter is applied
            username (str/None): user name; when falsy, no user filter is applied

        Returns:
            int: SERVER_ERROR for an unsupported update_type; NO_DATA on 'init' when some
                host ids were not found (nothing is updated then); otherwise SUCCEED
        """
        if update_type not in ("init", "finish"):
            LOGGER.error("Given host scan update type '%s' is not in default type list "
                         "['init', 'finish']." % update_type)
            return SERVER_ERROR

        if update_type == "init":
            new_values = {Host.status: "scanning", Host.last_scan: int(time())}
        else:
            new_values = {Host.status: "done"}

        scan_query = self._query_scan_status_and_time(host_list, username)
        found_ids = [row.host_id for row in scan_query]
        missing_ids = set(host_list) - set(found_ids)
        if missing_ids:
            LOGGER.debug("No data found when setting the status of host: %s." % missing_ids)
            # 'init' is all-or-nothing: refuse to touch any host if one is unknown
            if update_type == "init":
                return NO_DATA

        # update() cannot evaluate the 'in_' criterion in-session, hence synchronize_session=False
        scan_query.update(new_values, synchronize_session=False)
        return SUCCEED

    def _query_scan_status_and_time(self, host_list, username):
        """
        Build a query over host id, status and last_scan.

        Args:
            host_list (list): host id list; when empty, no host id filter is applied
            username (str/None): user name; when falsy, no user filter is applied

        Returns:
            sqlalchemy.orm.query.Query
        """
        conditions = []
        if host_list:
            conditions.append(Host.host_id.in_(host_list))
        if username:
            conditions.append(Host.user == username)

        return self.session.query(Host.host_id, Host.status, Host.last_scan).filter(*conditions)

    def save_scan_result(self, username, host_dict):
        """
        Persist a scan result in one transaction.

        If a host id is not present in the Host table the database raises, and every
        insertion is rolled back, so hosts must not be deleted while a scan is running.

        Args:
            username (str): username
            host_dict (dict): scanned hosts' cve list, e.g.
                {
                    "id1": ["cve1", "cve2"],
                    "id2": [],
                }

        Returns:
            int: status code (DATABASE_INSERT_ERROR when the database operation fails)
        """
        try:
            code = self._save_scan_result(username, host_dict)
            self.session.commit()
        except SQLAlchemyError as error:
            self.session.rollback()
            LOGGER.error(error)
            LOGGER.error("Saving scan result failed due to internal error.")
            return DATABASE_INSERT_ERROR
        LOGGER.debug("Finished saving scan result.")
        return code

    def _save_scan_result(self, username, host_dict):
        """
        Replace the scanned hosts' records in CveHostAssociation with the new scan result.

        The previous cve records of every scanned host are deleted first. Then a record is
        added for each (host, cve) pair whose cve already exists in the Cve table. Cves
        unknown to the Cve table are skipped and logged at debug level. A cve id repeated
        within one host's list is recorded only once.

        Args:
            username (str): username
            host_dict (dict): scanned hosts' cve list, e.g. {"id1": ["cve1"], "id2": []}

        Returns:
            int: SUCCEED
        """
        host_list = list(host_dict.keys())

        self.session.query(CveHostAssociation) \
            .filter(CveHostAssociation.host_id.in_(host_list)) \
            .delete(synchronize_session=False)

        scanned_cves = list({cve_id for cve_list in host_dict.values() for cve_id in cve_list})
        exist_cve_set = set()
        if scanned_cves:
            # look up only the cves present in this scan instead of loading the whole Cve table
            exist_cve_query = self.session.query(Cve.cve_id).filter(Cve.cve_id.in_(scanned_cves))
            exist_cve_set = {row.cve_id for row in exist_cve_query}

        cve_host_rows = []
        valid_cve_set = set()
        for host_id, cve_list in host_dict.items():
            # dict.fromkeys drops repeated cve ids while keeping their order, so the same
            # (host_id, cve_id) row is never inserted twice
            for cve_id in dict.fromkeys(cve_list):
                if cve_id not in exist_cve_set:
                    LOGGER.debug("Cve '%s' in Scan result cannot be recorded because its data has "
                                 "not been imported yet." % cve_id)
                    continue
                cve_host_rows.append({"host_id": host_id, "cve_id": cve_id})
                valid_cve_set.add(cve_id)

        self.session.bulk_insert_mappings(CveHostAssociation, cve_host_rows)
        self._update_user_cve_status(username, valid_cve_set)
        return SUCCEED

    def _update_user_cve_status(self, username, cve_set):
        """
        Add CveUserAssociation records for cves the user has not seen before.

        New records get the status "not reviewed". Existing records of the user are left
        untouched, including those of cves no longer found on any host.

        Args:
            username (str): user name
            cve_set (set): cve ids to make sure are recorded for the user

        Returns:
            None
        """
        known_query = self.session.query(CveUserAssociation.cve_id) \
            .filter(CveUserAssociation.user_name == username)
        known_cves = {row.cve_id for row in known_query}

        new_rows = [{"cve_id": cve_id, "user_name": username, "status": "not reviewed"}
                    for cve_id in cve_set - known_cves]
        self.session.bulk_insert_mappings(CveUserAssociation, new_rows)

    def get_task_list(self, data):
        """
        Get a user's task list, optionally filtered, sorted and paged.

        Args:
            data (dict): parameter, e.g.
                {
                    "username": "admin",
                    "sort": "host_num",
                    "direction": "asc",
                    "page": 1,
                    "per_page": 10,
                    "filter": {
                        "task_name": "task2",
                        "task_type": ["repo"]
                    }
                }

        Returns:
            int: status code
            dict: query result ({} on database error), e.g.
                {
                    "total_count": 1,
                    "total_page": 1,
                    "result": [
                        {
                            "task_id": "id1",
                            "task_name": "task1",
                            "task_type": "cve",
                            "description": "a long description",
                            "host_num": 12,
                            "create_time": 1111111111
                        }
                    ]
                }
        """
        try:
            task_page = self._get_processed_task_list(data)
        except SQLAlchemyError as error:
            LOGGER.error(error)
            LOGGER.error("Getting task list failed due to internal error.")
            return DATABASE_QUERY_ERROR, {}
        LOGGER.debug("Finished getting task list.")
        return SUCCEED, task_page

    def _get_processed_task_list(self, data):
        """
        Get a user's filtered, sorted and paged task list.

        Args:
            data (dict): "username" plus the optional keys "filter", "sort", "direction",
                "page" and "per_page" (see get_task_list)

        Returns:
            dict: {"total_count": int, "total_page": int, "result": list}; counts are 0 and
                the list is empty when no task matches
        """
        result = {
            "total_count": 0,
            "total_page": 0,
            "result": []
        }

        filters = self._get_task_list_filters(data.get("filter"))
        task_list_query = self._query_task_list(data["username"], filters)

        # count in the database rather than fetching every row only to measure its length
        total_count = task_list_query.count()
        if not total_count:
            return result

        # a missing, None or empty "sort" means "no sorting"; getattr would raise on the latter two
        sort_name = data.get("sort")
        sort_column = getattr(Task, sort_name) if sort_name else None
        direction, page, per_page = data.get('direction'), data.get('page'), data.get('per_page')

        processed_query, total_page = sort_and_page(task_list_query, sort_column,
                                                    direction, per_page, page)

        result['result'] = self._task_list_row2dict(processed_query)
        result['total_page'] = total_page
        result['total_count'] = total_count

        return result

    def _query_task_list(self, username, filters):
        """
        Build the query selecting a user's tasks that satisfy the given filters.

        Args:
            username (str): user name of the request
            filters (set): extra SQLAlchemy criteria built from the user's filter

        Returns:
            sqlalchemy.orm.query.Query
        """
        columns = (Task.task_id, Task.task_name, Task.task_type,
                   Task.description, Task.host_num, Task.create_time)
        return self.session.query(*columns).filter(Task.username == username, *filters)

    @staticmethod
    def _task_list_row2dict(rows):
        """
        Convert task list rows into plain dicts.

        Args:
            rows (iterable): rows exposing task_id, task_name, task_type, description,
                host_num and create_time attributes

        Returns:
            list: one dict per row, keyed by those attribute names
        """
        keys = ("task_id", "task_name", "task_type", "description", "host_num", "create_time")
        return [{key: getattr(row, key) for key in keys} for row in rows]

    @staticmethod
    def _get_task_list_filters(filter_dict):
        """
        Translate the task list filter into SQLAlchemy criteria.

        Args:
            filter_dict (dict/None): e.g.
                {
                    "task_name": "task2",            # substring match
                    "task_type": ["cve", "repo"]     # any of the listed types
                }

        Returns:
            set: criteria; empty when filter_dict is falsy or has no usable value
        """
        filters = set()
        if filter_dict:
            task_name = filter_dict.get("task_name")
            if task_name:
                filters.add(Task.task_name.like("%" + task_name + "%"))
            task_type = filter_dict.get("task_type")
            if task_type:
                filters.add(Task.task_type.in_(task_type))
        return filters

    def get_task_progress(self, data):
        """
        Get the per-host progress counters of the given tasks.

        Args:
            data (dict): parameter, e.g.
                {
                    "task_list": ["task1", "task2"],
                    "username": "admin"
                }

        Returns:
            int: status code (DATABASE_QUERY_ERROR on database error or missing key)
            dict: query result ({} on error), e.g.
                {
                    "result": {
                        "task1": {
                            "succeed": 1,
                            "fail": 0,
                            "running": 11,
                            "unknown": 0
                        },
                        "task2": {
                            "succeed": 12,
                            "fail": 0,
                            "running": 0,
                            "unknown": 0
                        }
                    }
                }
        """
        try:
            status_code, progress = self._get_processed_task_progress(data)
        except (SQLAlchemyError, KeyError) as error:
            LOGGER.error(error)
            LOGGER.error("Getting task progress failed due to internal error.")
            return DATABASE_QUERY_ERROR, {}
        LOGGER.debug("Finished getting task progress.")
        return status_code, progress

    def _get_processed_task_progress(self, data):
        """
        Compute the progress of each requested task, dispatching on task type.

        Args:
            data (dict): "task_list" (list of task ids) and "username"

        Returns:
            int: status code from judge_return_code over the found / missing task ids
                (NO_DATA is passed as its failure code)
            dict: {"result": {task_id: {"succeed": n, "fail": n, "running": n, "unknown": n}}}

        Raises:
            KeyError: when "task_list" or "username" is absent from data
        """
        task_list = data["task_list"]
        cve_tasks, repo_tasks = self._split_task_list(data["username"], task_list)

        progress = {}
        progress.update(self._get_cve_task_progress(cve_tasks))
        progress.update(self._get_repo_task_progress(repo_tasks))

        found = list(progress.keys())
        missing = list(set(task_list) - set(found))
        if missing:
            LOGGER.debug("No data found when getting the progress of task: %s." % missing)

        status_code = judge_return_code({"succeed_list": found, "fail_list": missing}, NO_DATA)
        return status_code, {"result": progress}

    def _split_task_list(self, username, task_list):
        """
        Split a user's task ids into cve tasks and repo tasks.

        Ids that do not belong to the user, or whose type is not in task_types, are dropped.

        Args:
            username (str): user name
            task_list (list): task id list

        Returns:
            list: cve task id list
            list: repo task id list
        """
        cve_task, repo_task = [], []

        # restrict to known types in case other task types get added to the Task table
        rows = self.session.query(Task.task_id, Task.task_type) \
            .filter(Task.username == username, Task.task_id.in_(task_list),
                    Task.task_type.in_(task_types))

        for row in rows:
            target = cve_task if row.task_type == "cve" else repo_task
            target.append(row.task_id)
        return cve_task, repo_task

    @staticmethod
    def _get_status_result():
        """
        Create a per-task progress counter.

        Returns:
            defaultdict: maps a task id to a fresh
                {"succeed": 0, "fail": 0, "running": 0, "unknown": 0} dict on first access
        """
        return defaultdict(lambda: {"succeed": 0, "fail": 0, "running": 0, "unknown": 0})

    def _get_cve_task_progress(self, task_list):
        """
        Count, per cve task, how many hosts are succeed / fail / running / unknown.

        A host's overall status is derived from the set of statuses of its cve records
        (see _get_cve_task_status). Tasks without any record are absent from the result
        and logged as errors.

        Args:
            task_list (list): cve tasks' id list

        Returns:
            dict: e.g.
                {"task1": {"succeed": 1, "fail": 0, "running": 10, "unknown": 1}}
        """
        # task_id -> host_id -> set of that host's cve statuses
        host_statuses = defaultdict(lambda: defaultdict(set))
        progress = self._get_status_result()

        for row in self._query_cve_task_host_status(task_list):
            host_statuses[row.task_id][row.host_id].add(row.status)

        for task_id, status_by_host in host_statuses.items():
            for status_set in status_by_host.values():
                progress[task_id][self._get_cve_task_status(status_set)] += 1

        missing = list(set(task_list) - set(progress.keys()))
        if missing:
            LOGGER.error("CVE task '%s' exist but status data is not record." % missing)
        return progress

    def _query_cve_task_host_status(self, task_list):
        """
        Build the query over (task_id, host_id, status) of cve task records.

        Args:
            task_list (list): task id list

        Returns:
            sqlalchemy.orm.query.Query
        """
        columns = (TaskCveHostAssociation.task_id,
                   TaskCveHostAssociation.host_id,
                   TaskCveHostAssociation.status)
        return self.session.query(*columns) \
            .filter(TaskCveHostAssociation.task_id.in_(task_list))

    @staticmethod
    def _get_cve_task_status(status_set):
        """
        Reduce a set of record statuses to one overall status.

        Precedence: any "running" -> "running"; else any "unknown" -> "unknown";
        else any "unfixed" -> "fail"; otherwise "succeed".

        Args:
            status_set (set): statuses of a host's or a cve's records

        Returns:
            str
        """
        precedence = (("running", "running"), ("unknown", "unknown"), ("unfixed", "fail"))
        for record_status, overall in precedence:
            if record_status in status_set:
                return overall
        return "succeed"

    def _get_repo_task_progress(self, task_list):
        """
        Count, per repo task, how many hosts are succeed / fail / running / unknown.

        Host statuses map as: "set" -> succeed, "unset" -> fail, "running" -> running,
        "unknown" -> unknown. Any other status is logged and not counted.

        Args:
            task_list (list): repo tasks' id list

        Returns:
            dict: e.g.
                {"task1": {"succeed": 1, "fail": 0, "running": 10, "unknown": 1}}
        """
        counter_of = {"set": "succeed", "unset": "fail", "running": "running",
                      "unknown": "unknown"}
        progress = self._get_status_result()

        for row in self._query_repo_task_host(task_list):
            counter = counter_of.get(row.status)
            if counter is None:
                LOGGER.error("Unknown repo task's host status '%s'" % row.status)
                continue
            progress[row.task_id][counter] += 1

        missing = list(set(task_list) - set(progress.keys()))
        if missing:
            LOGGER.error("Repo task '%s' exist but status data is not record." % missing)
        return progress

    def _query_repo_task_host(self, task_list):
        """
        Build the query over (task_id, status) of the hosts in the given repo tasks.

        Args:
            task_list (list): task id list

        Returns:
            sqlalchemy.orm.query.Query
        """
        association = TaskHostRepoAssociation
        return self.session.query(association.task_id, association.status) \
            .filter(association.task_id.in_(task_list))

    def get_task_info(self, data):
        """
        Get the basic info of one of the user's tasks.

        Args:
            data (dict): parameter, e.g.
                {
                    "task_id": "id1",
                    "username": "admin"
                }

        Returns:
            int: status code (DATABASE_QUERY_ERROR on database error)
            dict: query result ({} on database error), e.g.
                {
                    "result": {
                        "task_name": "task",
                        "description": "a long description",
                        "host_num": 2,
                        "need_reboot": 1,
                        "auto_reboot": True,
                        "latest_execute_time": 1111111111
                    }
                }
        """
        try:
            status_code, task_info = self._get_processed_task_info(data)
        except SQLAlchemyError as error:
            LOGGER.error(error)
            LOGGER.error("Getting task info failed due to internal error.")
            return DATABASE_QUERY_ERROR, {}
        LOGGER.debug("Finished getting task info.")
        return status_code, task_info

    def _get_processed_task_info(self, data):
        """
        Query one task's info and convert it to a dict.

        Args:
            data (dict): "task_id" and "username"

        Returns:
            int: NO_DATA when the task is not found for the user, else SUCCEED
            dict: {"result": task info dict}, with an empty dict when not found

        Raises:
            sqlalchemy.orm.exc.MultipleResultsFound: when more than one row matches
                (an SQLAlchemyError, handled by get_task_info)
        """
        task_id = data["task_id"]
        username = data["username"]

        # one_or_none() needs a single round trip (the previous .all() + .one() ran the query
        # twice): it returns None for no row and raises MultipleResultsFound for several rows
        task_info_data = self._query_task_info_from_mysql(username, task_id).one_or_none()
        if task_info_data is None:
            LOGGER.debug("No data found when getting the info of task: %s." % task_id)
            return NO_DATA, {"result": {}}

        info_dict = self._task_info_row2dict(task_info_data)
        return SUCCEED, {"result": info_dict}

    def _query_task_info_from_mysql(self, username, task_id):
        """
        Build the query for one task's basic info, restricted to its owner.

        Args:
            username (str): user name of the request
            task_id (str): task id

        Returns:
            sqlalchemy.orm.query.Query
        """
        columns = (Task.task_name, Task.description, Task.host_num,
                   Task.need_reboot, Task.auto_reboot, Task.latest_execute_time)
        return self.session.query(*columns) \
            .filter(Task.task_id == task_id, Task.username == username)

    @staticmethod
    def _task_info_row2dict(row):
        """
        Convert a task info row into a dict keyed by its column names.

        Args:
            row: row exposing task_name, description, host_num, need_reboot,
                auto_reboot and latest_execute_time attributes

        Returns:
            dict
        """
        keys = ("task_name", "description", "host_num", "need_reboot", "auto_reboot",
                "latest_execute_time")
        return {key: getattr(row, key) for key in keys}

    def get_cve_task_info(self, data):
        """
        Get the per-cve info of a cve fixing task.

        Args:
            data (dict): parameter, e.g.
                {
                    "task_id": "id1",
                    "sort": "host_num",
                    "direction": "asc",
                    "page": 1,
                    "per_page": 10,
                    "username": "admin",
                    "filter": {
                        "cve_id": "",
                        "reboot": True,
                        "status": []
                    }
                }

        Returns:
            int: status code (DATABASE_QUERY_ERROR on database error)
            dict: task's cve info ({} on database error), e.g.
                {
                    "total_count": 1,
                    "total_page": 1,
                    "result": [{
                        "cve_id": "id1",
                        "package": "tensorflow",
                        "reboot": True,
                        "host_num": 3,
                        "status": "running"
                    }]
                }
        """
        try:
            cve_page = self._get_processed_cve_task(data)
        except SQLAlchemyError as error:
            LOGGER.error(error)
            LOGGER.error("Getting task's cve info failed due to internal error.")
            return DATABASE_QUERY_ERROR, {}
        LOGGER.debug("Finished getting task's cve info.")
        return SUCCEED, cve_page

    def _get_processed_cve_task(self, data):
        """
        Query, filter, sort and page a cve task's per-cve info.

        Args:
            data (dict): query condition ("task_id", "username" and the optional
                "filter", "sort", "direction", "page", "per_page"; see get_cve_task_info)

        Returns:
            dict: {"total_count": int, "total_page": int, "result": list}. When nothing
                matches, total_count is 0, total_page is 1 and result is empty.
        """
        result = {
            "total_count": 0,
            "total_page": 1,
            "result": []
        }

        task_id = data["task_id"]
        # "or {}" also covers an explicit "filter": None, which the filter builder cannot use
        filter_dict = data.get("filter") or {}
        filters = self._get_cve_task_filters(filter_dict)
        task_cve_query = self._query_cve_task(data["username"], task_id, filters)
        cve_info_list = self._process_cve_task_data(task_cve_query, filter_dict)

        total_count = len(cve_info_list)
        # NO_DATA is NOT reported: an empty result is a normal outcome of filtering
        if not total_count:
            return result

        processed_result, total_page = self._sort_and_page_task_cve(cve_info_list, data)
        result['result'] = processed_result
        result['total_page'] = total_page
        result['total_count'] = total_count

        return result

    @staticmethod
    def _get_cve_task_filters(filter_dict):
        """
        Build SQL criteria for a cve task's cve info.

        Filtering by status is applied later in Python (see _process_cve_task_data),
        because a cve's status is derived from several records.

        Args:
            filter_dict (dict/None): e.g.
                {
                    "cve_id": "",       # substring match, ignored when empty
                    "reboot": True,     # exact match on True or False, ignored when absent/None
                    "status": [""]
                }

        Returns:
            set: criteria; empty when filter_dict is falsy
        """
        filters = set()
        if not filter_dict:
            return filters

        if filter_dict.get("cve_id"):
            filters.add(Cve.cve_id.like("%" + filter_dict["cve_id"] + "%"))
        # compare with None: reboot=False is a real filter and a truthiness test would drop it
        if filter_dict.get("reboot") is not None:
            filters.add(Cve.reboot == filter_dict["reboot"])
        return filters

    def _query_cve_task(self, username, task_id, filters):
        """
        Build the query joining a user's cve task with its cves, packages and hosts.

        Args:
            username (str): user name of the request
            task_id (str): task id
            filters (set): extra SQLAlchemy criteria built from the user's filter

        Returns:
            sqlalchemy.orm.query.Query. Each row exposes:
                cve_id, reboot, package, host_id, status
                e.g. ("CVE-2021-0001", True, "tensorflow", "id1", "fixed")
        """
        selected = (Cve.cve_id, Cve.reboot, CveAffectedPkgs.package,
                    TaskCveHostAssociation.host_id, TaskCveHostAssociation.status)

        query = self.session.query(*selected)
        query = query.join(TaskCveHostAssociation, TaskCveHostAssociation.cve_id == Cve.cve_id)
        query = query.join(CveAffectedPkgs, CveAffectedPkgs.cve_id == Cve.cve_id)
        query = query.join(Task, Task.task_id == TaskCveHostAssociation.task_id)
        query = query.filter(Task.task_id == task_id, Task.username == username)
        return query.filter(*filters)

    def _process_cve_task_data(self, task_cve_query, filter_dict):
        """
        Aggregate a cve task's rows per cve, then filter the cves by overall status.

        Rows of one cve are merged as follows:
            - package: the distinct package names, sorted and comma-joined
            - host_num: the number of distinct hosts
            - status: derived from the records' statuses (see _get_cve_task_status)
            - reboot: taken from the cve's first row

        Args:
            task_cve_query (sqlalchemy.orm.query.Query): query result of cve task's cve info,
                each row exposing cve_id, package, reboot, host_id, status
            filter_dict (None/dict): its "status" list, if present, keeps only cves whose
                overall status is listed; an empty list keeps nothing

        Returns:
            list. e.g.
                [{
                    "package": "tensorflow",
                    "reboot": True,
                    "cve_id": "CVE-2021-0001",
                    "host_num": 3,
                    "status": "running"
                }]
        """
        need_status = filter_dict.get("status") if filter_dict else None
        status_filter_on = isinstance(need_status, list)
        # an explicit empty status list can match nothing, so skip running the query at all
        if status_filter_on and not need_status:
            return []

        cve_dict = {}
        for row in task_cve_query:
            cve_info = cve_dict.get(row.cve_id)
            if cve_info is None:
                cve_info = {"package": set(), "reboot": row.reboot,
                            "host_set": set(), "status_set": set()}
                cve_dict[row.cve_id] = cve_info
            cve_info["package"].add(row.package)
            cve_info["host_set"].add(row.host_id)
            cve_info["status_set"].add(row.status)

        cve_info_list = []
        for cve_id, cve_info in cve_dict.items():
            cve_status = self._get_cve_task_status(cve_info["status_set"])
            if status_filter_on and cve_status not in need_status:
                continue
            cve_info_list.append({
                # sorted so the joined string is deterministic (set iteration order is not)
                "package": ','.join(sorted(cve_info["package"])),
                "reboot": cve_info["reboot"],
                "cve_id": cve_id,
                "host_num": len(cve_info["host_set"]),
                "status": cve_status
            })
        return cve_info_list

    @staticmethod
    def _sort_and_page_task_cve(cve_info_list, data):
        """
        Sort (in place) and page a cve task's cve info list.

        The list is always ordered by host_num; it is descending only when
        data["sort"] == "host_num" and data["direction"] == "desc". Paging applies only
        when both "page" and "per_page" are given and truthy.

        Args:
            cve_info_list (list): cve task's cve info list, not empty; sorted in place
            data (dict): parameter, e.g.
                {
                    "task_id": "id1",
                    "sort": "host_num",
                    "direction": "asc",
                    "page": 1,
                    "per_page": 10,
                    "username": "admin",
                    "filter": {
                        "cve_id": "",
                        "reboot": True,
                        "status": []
                    }
                }

        Returns:
            list: sorted (and possibly paged) cve info list
            int: total page (1 when not paged)
        """
        page, per_page = data.get('page'), data.get('per_page')
        descending = data.get("sort") == "host_num" and data.get("direction") == "desc"

        cve_info_list.sort(key=lambda info: info["host_num"], reverse=descending)

        if not (page and per_page):
            return cve_info_list, 1

        total_page = math.ceil(len(cve_info_list) / per_page)
        start = per_page * (page - 1)
        return cve_info_list[start:start + per_page], total_page
